import itertools
import math
from typing import Sequence, Any

import gym.spaces as spaces

from centralized_verification.envs.fast_grid_world import FastGridWorld


class FastGridWorldNearbyObs(FastGridWorld):
    """
    Grid world where each agent sees its own position and the nearby positions of the other agents.

    Each agent's observation is a tuple ``(my_pos_idx, o_1, ..., o_{num_agents - 1})``:
    - ``my_pos_idx`` is the observing agent's own position index.
    - Each ``o_k`` encodes where one of the other agents is relative to it, with the other agents in their
      original order and the observing agent skipped.
    - A code in ``1..(2 * obs_radius + 1) ** 2`` identifies a relative (dx, dy) offset within ``obs_radius``
      squares along each axis (a square window).
    - Code ``0`` means that agent is outside the observable window.

    Similarly to the base FastGridWorld, the offset code for every pair of position indices is pre-calculated
    and cached (``offset_map``). ``coord_offset_inv`` decodes a code back into its (dx, dy) offset.
    """

    def __init__(self, *args, obs_radius: int = 2, **kwargs):
        """
        :param obs_radius: how many squares away, along each axis, another agent remains visible.
        All other positional and keyword arguments are forwarded unchanged to FastGridWorld.
        """
        super().__init__(*args, **kwargs)
        self.obs_radius = obs_radius

        # Every relative (dx, dy) offset inside the observation window, in row-major order:
        # (-r, -r), (-r, -r + 1), ..., (r, r). For obs_radius 2 that is 25 offsets.
        visible_offsets = list(itertools.product(range(-obs_radius, obs_radius + 1), repeat=2))

        # Code used for an agent that is not visible; visible offsets are numbered starting from 1.
        not_visible_coord = 0

        # Maps each visible offset to its unique code, e.g. with obs_radius 2:
        # {(-2, -2): 1, (-2, -1): 2, ..., (2, 1): 24, (2, 2): 25}
        coord_offsets = {offset: idx + 1 for idx, offset in enumerate(visible_offsets)}

        def get_other_agent_offset_map(this_pos_idx, other_pos_idx):
            """Code for the other position relative to this one, or not_visible_coord if outside the window."""
            this_x, this_y = self.grid_posns[this_pos_idx]
            other_x, other_y = self.grid_posns[other_pos_idx]
            return coord_offsets.get((other_x - this_x, other_y - this_y), not_visible_coord)

        num_posns = len(self.grid_posns)
        # offset_map[my_pos_idx][other_pos_idx] -> code of the other agent's position relative to mine
        # (0 if it is not visible).
        self.offset_map = [[get_other_agent_offset_map(this_pos_idx, other_pos_idx)
                            for other_pos_idx in range(num_posns)]
                           for this_pos_idx in range(num_posns)]

        # Inverse of the encoding: coord_offset_inv[code] -> (dx, dy).
        # Code 0 (not visible) decodes to (nan, nan); code k >= 1 decodes to visible_offsets[k - 1].
        self.coord_offset_inv = [(math.nan, math.nan)] + visible_offsets

        # Not read anywhere within this class; set after the offset tables are built.
        self.cached_specifications = None

    def agent_obs_spaces(self) -> Sequence[spaces.Space]:
        """
        Observation space for each agent; every agent gets the same space.

        The space is a MultiDiscrete with these components:
        - the agent's own position index (``len(self.grid_posns)`` values);
        - one component per other agent with ``(2 * obs_radius + 1) ** 2 + 1`` values, covering every
          visible offset code plus the "not visible" code 0.
        """
        num_visibility_squares = ((self.obs_radius * 2) + 1) ** 2  # Number of relative positions we can see for others
        obs_for_other_agents = [num_visibility_squares + 1] * (self.num_agents - 1)  # Plus when others are out of range
        whole_obs_for_single_agent = spaces.MultiDiscrete([len(self.grid_posns), *obs_for_other_agents])
        return [whole_obs_for_single_agent] * self.num_agents

    def project_single_obs(self, state, agent_idx):
        """
        Build the observation of agent ``agent_idx`` from the joint ``state``.

        :param state: sequence holding one position index per agent, indexed by agent.
        :param agent_idx: index of the observing agent within ``state``.
        :return: tuple ``(my_pos_idx, *offset codes of every other agent in order)``.
        """
        my_pos_idx = state[agent_idx]
        my_offset_row = self.offset_map[my_pos_idx]
        other_agent_codes = (my_offset_row[other_pos_idx]
                             for i, other_pos_idx in enumerate(state) if i != agent_idx)
        return (my_pos_idx, *other_agent_codes)

    def project_obs(self, state) -> Sequence[Any]:
        """Observations of all agents for the joint ``state``, ordered by agent index."""
        return tuple(self.project_single_obs(state, i) for i in range(self.num_agents))
